import logging
import time

import numpy as np
import torch
import yaml
from torchvision import transforms

from defense.base_defense import BaseDefense
from diffusion.utils import dict2namespace, restore_checkpoint
from diffusion.utils import diff2clf, clf2diff
from diffusion.guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults

# Purification schedules. Each comma-separated entry of a *_max_timesteps /
# *_num_diffusion_steps pair describes one noise-then-denoise stage, as parsed
# by get_diffusion_params. DiffusionDefense reads only the 'def_*' keys; the
# 'att_*' keys and 'is_imagenet' are not read in this file (presumably consumed
# by attack/evaluation code elsewhere — confirm before changing them).
standard_args = {
    'att_max_timesteps': '30,50,200',
    'att_num_diffusion_steps': '1,1,5',
    # Eight defense stages; num_diffusion_steps == max_timesteps means every
    # stage is denoised one timestep at a time.
    'def_max_timesteps': '30,30,30,30,50,50,200,200',
    'def_num_diffusion_steps': '30,30,30,30,50,50,200,200',
    'att_sampling_method': 'ddpm',
    'def_sampling_method': 'ddpm',
    'is_imagenet': True,
}

# Variant whose defense schedule is the same as the attack schedule above.
standard_args2 = {
    **standard_args,
    'def_max_timesteps': '30,50,200',
    'def_num_diffusion_steps': '1,1,5',
}

class DiffusionDefense(BaseDefense):
    """Purification defense built on an unconditional guided-diffusion model.

    The defense schedule (noising timesteps, denoising steps and sampler) comes
    from the ``def_*`` entries of the module-level ``standard_args`` dict.
    """

    def __init__(self, device=None, iterations=1):
        """
        Args:
            device: Forwarded to ``BaseDefense``, which presumably sets
                ``self.device``; that attribute places the model and schedule.
            iterations: Forwarded to ``BaseDefense``; ``self.iterations`` is the
                number of purification passes applied by ``_defense``.
        """
        super().__init__(device, iterations)
        # _create_model reads self.device, so the base initializer runs first.
        self.diffusion_model = self._create_model()
        schedule = get_diffusion_params(
            standard_args['def_max_timesteps'],
            standard_args['def_num_diffusion_steps'],
        )
        self.max_timesteps, self.diffusion_steps = schedule
        self.diffusion_forward = PurificationForward(
            diffusion=self.diffusion_model,
            max_timestep=self.max_timesteps,
            attack_steps=self.diffusion_steps,
            sampling_method=standard_args['def_sampling_method'],
            device=self.device,
        )

    def _defense(self, x):
        """Purify ``x`` without tracking gradients.

        The batch is resized to 256x256 before purification and the result is
        resized to 224x224, whatever the input's original spatial size was.
        """
        to_diffusion_res = transforms.Resize((256, 256))
        to_output_res = transforms.Resize((224, 224))
        with torch.no_grad():
            purified = to_diffusion_res(x)
            for _ in range(self.iterations):
                purified = self.diffusion_forward(purified)
            purified = to_output_res(purified)
        return purified

    def _create_model(self):
        """Build the diffusion network and load its pretrained weights.

        The ``model`` section of ``diffusion/diffusion_configs/imagenet.yml``
        overrides ``model_and_diffusion_defaults()``, and the state dict is read
        from ``data/diffusion_uncond.pt``; both paths are relative to the current
        working directory. Returns the network in eval mode on ``self.device``.
        """
        with open('diffusion/diffusion_configs/imagenet.yml', 'r') as cfg_file:
            # yaml.Loader can construct arbitrary Python objects; acceptable only
            # because this is a trusted, repo-local config file.
            raw_cfg = yaml.load(cfg_file, Loader=yaml.Loader)
        cfg = dict2namespace(raw_cfg)
        model_kwargs = model_and_diffusion_defaults()
        model_kwargs.update(vars(cfg.model))
        # The second return value of create_model_and_diffusion is not needed.
        model, _ = create_model_and_diffusion(**model_kwargs)
        state_dict = torch.load('data/diffusion_uncond.pt', map_location='cpu')
        model.load_state_dict(state_dict)
        # nn.Module.eval() and nn.Module.to() both return the module itself.
        return model.eval().to(self.device)

    
def get_diffusion_params(max_timesteps, num_denoising_steps):
    """Parse comma-separated purification schedules into timestep lists.

    Each comma-separated entry describes one stage: the input is noised to
    zero-based timestep ``max_t - 1`` and then denoised in ``num_steps`` evenly
    spaced steps that end exactly at that timestep.

    Args:
        max_timesteps: Comma-separated ints, e.g. ``'30,50,200'``.
        num_denoising_steps: Comma-separated ints with the same number of
            entries, e.g. ``'1,1,5'``.

    Returns:
        ``(max_timestep_list, diffusion_steps)``. ``max_timestep_list[i]`` is the
        zero-based noising timestep of stage ``i``. ``diffusion_steps[i]`` is the
        ascending list of ``num_steps`` zero-based timesteps visited while
        denoising that stage; its last element equals ``max_timestep_list[i]``.

    Raises:
        ValueError: if the two strings have different numbers of entries, an
            entry is not an integer, or a stage asks for fewer than one step or
            for more steps than it has timesteps.
    """
    max_timestep_list = [int(v) for v in max_timesteps.split(',')]
    num_steps_list = [int(v) for v in num_denoising_steps.split(',')]
    if len(max_timestep_list) != len(num_steps_list):
        raise ValueError(
            f"max_timesteps has {len(max_timestep_list)} entries but "
            f"num_denoising_steps has {len(num_steps_list)}")

    noise_timesteps = []
    diffusion_steps = []
    for max_t, num_steps in zip(max_timestep_list, num_steps_list):
        if num_steps < 1 or num_steps > max_t:
            raise ValueError(
                f"num_denoising_steps must be in [1, {max_t}] for max_timestep "
                f"{max_t}, got {num_steps}")
        stride = max_t // num_steps
        top = max_t - 1
        # Anchor the schedule at the noising timestep so the first denoising
        # step is evaluated at the same t the input was noised to, and emit
        # exactly num_steps entries. When num_steps divides max_t this is
        # simply stride-1, 2*stride-1, ..., max_t-1.
        diffusion_steps.append([top - k * stride for k in reversed(range(num_steps))])
        noise_timesteps.append(top)

    return noise_timesteps, diffusion_steps

def get_beta_schedule(beta_start, beta_end, num_diffusion_timesteps):
    """Return a linear beta schedule as a 1-D float32 tensor.

    The values are spaced in float64 with NumPy and only then cast to float32,
    running from ``beta_start`` to ``beta_end`` inclusive.

    Args:
        beta_start: First beta value.
        beta_end: Last beta value.
        num_diffusion_timesteps: Number of values in the schedule.
    """
    schedule = np.linspace(
        beta_start, beta_end, num=num_diffusion_timesteps, dtype=np.float64)
    assert schedule.shape == (num_diffusion_timesteps,)
    as_tensor = torch.from_numpy(schedule)
    return as_tensor.to(torch.float32)

class PurificationForward(torch.nn.Module):
    """Noise-then-denoise purification with a pretrained diffusion model.

    For each stage ``i`` the input is noised to timestep ``max_timestep[i]``
    under a linear beta schedule and then denoised along the timesteps in
    ``attack_steps[i]``. The update rule is the DDIM-style step in
    ``denoising_process``; ``eta`` scales the fresh noise injected per step.
    """

    # Per-step timing goes to logging at DEBUG level instead of stdout, so
    # purification does not print several lines for every denoising step.
    _logger = logging.getLogger(__name__)

    def __init__(self, diffusion, max_timestep, attack_steps, sampling_method, device):
        """
        Args:
            diffusion: Model called as ``diffusion(x_t, t)``; the first
                ``x_t.shape[1]`` output channels are used as the noise estimate.
            max_timestep: Zero-based noising timestep for each stage.
            attack_steps: Ascending timestep list for each stage.
            sampling_method: ``'ddim'`` (eta=0) or ``'ddpm'`` (eta=1).
            device: Device the beta schedule is moved to.

        Raises:
            ValueError: if ``sampling_method`` is not ``'ddim'`` or ``'ddpm'``.
        """
        super().__init__()
        if sampling_method not in ('ddim', 'ddpm'):
            raise ValueError(
                f"sampling_method must be 'ddim' or 'ddpm', got {sampling_method!r}")
        self.diffusion = diffusion
        self.betas = get_beta_schedule(1e-4, 2e-2, 1000).to(device)
        self.max_timestep = max_timestep
        self.attack_steps = attack_steps
        self.sampling_method = sampling_method
        # eta multiplies the random-noise coefficient c1 in denoising_process;
        # 0 makes every update deterministic.
        self.eta = 0 if sampling_method == 'ddim' else 1

    def compute_alpha(self, t):
        """Return the cumulative product of (1 - beta) at each timestep in ``t``.

        ``t`` is a LongTensor of zero-based timesteps. A leading zero beta is
        prepended, so ``t == -1`` maps to 1.0 (no noise). The result is viewed
        as ``(len(t), 1, 1, 1)`` to broadcast over 4-D batches.
        """
        beta = torch.cat(
            [torch.zeros(1, device=self.betas.device), self.betas], dim=0)
        return (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)

    def get_noised_x(self, x, t):
        """Noise ``x`` to zero-based timestep ``t`` with fresh Gaussian noise.

        ``t`` may be a Python int (used for the whole batch) or a LongTensor
        with one timestep per batch element.
        """
        e = torch.randn_like(x)
        if isinstance(t, int):
            t = torch.full((x.shape[0],), t, dtype=torch.long, device=x.device)
        a = (1 - self.betas).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
        return x * a.sqrt() + e * (1.0 - a).sqrt()

    def denoising_process(self, x, seq):
        """Denoise ``x`` from timestep ``seq[-1]`` down to a clean estimate.

        Args:
            x: Batch that was noised to timestep ``seq[-1]``.
            seq: Ascending zero-based timesteps. Each step moves from ``seq[k]``
                to ``seq[k - 1]``, and from ``seq[0]`` to ``-1`` (alpha = 1).

        Returns:
            The denoised batch; ``x`` itself when ``seq`` is empty.
        """
        n = x.size(0)
        n_channels = x.size(1)
        seq_next = [-1] + list(seq[:-1])
        xt = x
        total_time = 0.0
        for i, j in zip(reversed(seq), reversed(seq_next)):
            step_start = time.perf_counter()
            t = (torch.ones(n) * i).to(x.device)
            next_t = (torch.ones(n) * j).to(x.device)
            at = self.compute_alpha(t.long())
            at_next = self.compute_alpha(next_t.long())

            model_start = time.perf_counter()
            et = self.diffusion(xt, t)
            # Wall-clock only: asynchronous CUDA execution may make this under-report.
            self._logger.debug(
                "Diffusion model took %.4f seconds", time.perf_counter() - model_start)
            # Keep only the noise estimate; any extra output channels (presumably
            # a learned variance) are dropped. Also works for models that output
            # exactly n_channels.
            et = et[:, :n_channels]

            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            c1 = self.eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            xt = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et

            step_time = time.perf_counter() - step_start
            total_time += step_time
            self._logger.debug("Step t=%s took %.4f seconds", i, step_time)

        if len(seq) > 0:
            self._logger.debug(
                "Denoising process: %d steps, avg per step: %.4f seconds",
                len(seq), total_time / len(seq))
        return xt

    def forward(self, x):
        """Purify ``x`` stage by stage.

        ``x`` is converted with ``clf2diff`` first. Each stage noises it to
        ``max_timestep[stage]`` and denoises along ``attack_steps[stage]``. The
        result is converted back with ``diff2clf``.
        """
        x_diff = clf2diff(x)
        for stage, noise_t in enumerate(self.max_timestep):
            noised_x = self.get_noised_x(x_diff, noise_t)
            x_diff = self.denoising_process(noised_x, self.attack_steps[stage])
        return diff2clf(x_diff)